import os
import re

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import tomotopy as tp
from sklearn.feature_extraction.text import CountVectorizer
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from wordcloud import WordCloud

# Download the NLTK 'punkt' and 'stopwords' resources. These are top-level
# statements, so the download calls execute every time this module is imported.
# NOTE(review): presumably needed by word_tokenize / stopwords imported below —
# comment these lines out once the resources are installed locally.
nltk.download('punkt')
nltk.download('stopwords')

from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords

from scipy.signal import savgol_filter



def get_data(country_names):
    """Load one Excel sheet per country and stack them into one DataFrame.

    For every name in ``country_names`` the file ``Data/<name>.xls`` is read
    with ``pd.read_excel`` and tagged with a ``country`` column holding that
    name.

    Returns:
        The row-wise concatenation of all per-country frames, re-indexed
        from zero.
    """
    frames = [
        pd.read_excel(f"Data/{name}.xls").assign(country=name)
        for name in country_names
    ]
    return pd.concat(frames, ignore_index=True)

def plot_platform_count(data, col):
    """Show a horizontal bar chart of how often each value of ``data[col]`` occurs.

    Values are ordered by ascending count, so the most frequent value ends up
    as the top bar. The chart is displayed immediately with ``plt.show()``.
    """
    counts_ascending = data[col].value_counts().sort_values(ascending=True)
    bar_colors = sns.palettes.mpl_palette('Dark2')
    counts_ascending.plot(kind='barh', color=bar_colors)

    # Hide the top and right frame lines of the current axes.
    axes = plt.gca()
    axes.spines[['top', 'right']].set_visible(False)

    axes.set_title("Number of articles per platform")
    axes.set_xlabel("Number of articles")
    axes.set_ylabel("Platform")
    plt.show()

def plot_word_cloud(string):
    """Render a word cloud built from an iterable of text strings.

    The strings in ``string`` are joined with commas into one text, passed to
    ``WordCloud.generate`` and displayed without axes.
    """
    cloud = WordCloud(
        background_color="white",
        max_words=10000,
        contour_width=3,
        contour_color='steelblue',
        width=800,
        height=400,
    )
    cloud.generate(','.join(string))

    plt.figure(figsize=(10, 8))
    plt.imshow(cloud, interpolation='bilinear')
    plt.axis("off")
    plt.show()

def plot_20_most_common_words(data, col_name, words_to_remove):
    """Plot the 20 most frequent words of a text column as a horizontal bar chart.

    Args:
        data: DataFrame holding the documents.
        col_name: Name of the column with the raw text of each document.
        words_to_remove: Iterable of words that must not appear in the chart.

    Words are counted with a ``CountVectorizer`` using its built-in English
    stop-word list. Excluded words are dropped *before* the top 20 are
    selected, so the chart still shows 20 bars whenever enough distinct words
    remain.
    """
    vectorizer = CountVectorizer(stop_words='english')
    count_data = vectorizer.fit_transform(data[col_name])
    vocabulary = vectorizer.get_feature_names_out()

    # Column sums of the sparse document-term matrix give the corpus-wide
    # count of every word without densifying the matrix row by row.
    total_counts = np.asarray(count_data.sum(axis=0)).ravel()

    excluded = set(words_to_remove)
    candidates = (
        (word, count)
        for word, count in zip(vocabulary, total_counts)
        if word not in excluded
    )
    top_words = sorted(candidates, key=lambda pair: pair[1], reverse=True)[:20]
    # barh draws the first entry at the bottom; reverse so the most common
    # word ends up at the top of the chart.
    top_words.reverse()

    words = [word for word, _ in top_words]
    counts = [count for _, count in top_words]
    x_pos = np.arange(len(words))

    plt.barh(x_pos, counts, color=sns.color_palette('husl', len(words)))
    plt.title('20 most common words')
    plt.yticks(x_pos, words)
    plt.xlabel('Counts')
    plt.ylabel('Words')
    plt.show()

# Text preprocessing: tokenise, drop stop words, strip punctuation.
_PUNCTUATION_PATTERN = re.compile(r'[^\w\s]')
_ENGLISH_STOP_WORDS = None  # filled lazily on first use, then reused


def _english_stop_words():
    """Return the NLTK English stop-word set, loading it only once."""
    global _ENGLISH_STOP_WORDS
    if _ENGLISH_STOP_WORDS is None:
        _ENGLISH_STOP_WORDS = frozenset(stopwords.words('english'))
    return _ENGLISH_STOP_WORDS


def preprocess_text(text):
    """Turn a raw text into a list of cleaned, lower-cased tokens.

    Steps, in order: lower-case and tokenise with ``word_tokenize``, drop
    English stop words, strip every character that is neither a word
    character nor whitespace, and discard tokens shorter than two characters.

    Args:
        text: The document text. Anything that is not a ``str`` (for example
            ``None`` or a NaN from an empty spreadsheet cell) is treated as
            an empty document.

    Returns:
        A list of token strings (possibly empty).
    """
    if not isinstance(text, str):
        return []

    stop_words = _english_stop_words()
    kept = (token for token in word_tokenize(text.lower()) if token not in stop_words)
    stripped = (_PUNCTUATION_PATTERN.sub('', token) for token in kept)
    # Stripping punctuation can leave empty or one-character leftovers.
    return [token for token in stripped if len(token) > 1]

def preprocess(data):
    """Return a copy of ``data`` with a ``preprocessed_text`` column added.

    The new column holds the result of ``preprocess_text`` applied to each
    entry of the ``"Full Text"`` column; the input frame is left untouched.
    """
    token_lists = data["Full Text"].apply(preprocess_text)
    return data.assign(preprocessed_text=token_lists)

# Sentiment of a whole text via VADER
def get_sentiment_score(text):
    """Return the VADER ``'compound'`` polarity score of ``text``.

    A new ``SentimentIntensityAnalyzer`` is created on every call.
    """
    return SentimentIntensityAnalyzer().polarity_scores(text)['compound']

# Average per-sentence compound score
def get_compound_score(sid, text):
    """Average the non-zero compound scores of the "."-separated parts of ``text``.

    Args:
        sid: Object exposing ``polarity_scores(str)`` that returns a mapping
            with a ``'compound'`` entry.
        text: Text to score; it is split on every ``"."`` character.

    Returns:
        The mean of all compound scores different from ``0.0``, or ``0`` when
        every part scores exactly zero.
    """
    all_scores = (sid.polarity_scores(part)['compound'] for part in text.split("."))
    nonzero_scores = [score for score in all_scores if score != 0.0]
    if not nonzero_scores:
        return 0
    return sum(nonzero_scores) / len(nonzero_scores)

# Line chart of raw versus smoothed sentiment
def visualize_sentiment_trends(dates, sentiment_scores, smoothed_scores):
    """Plot raw and smoothed sentiment scores against ``dates`` on one chart.

    The raw scores are drawn in blue with circle markers, the smoothed scores
    as a plain red line. The figure is shown immediately.
    """
    figure = plt.figure(figsize=(10, 6))
    axes = figure.gca()

    axes.plot(dates, sentiment_scores, label='Actual Sentiment Scores', marker='o', linestyle='-', color='b')
    axes.plot(dates, smoothed_scores, label='Smoothed Sentiment Scores', linestyle='-', color='r')

    axes.set_xlabel('Date')
    axes.set_ylabel('Sentiment Score')
    axes.set_title('Sentiment Trends Over Time')
    axes.legend()
    axes.grid(True)

    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# Train a tomotopy DMR topic model
def dmr_topic_modeling(text_data, metadata, iterations=100, prior_tp=None, word_clusters=None):
    """Build and train a 6-topic ``tp.DMRModel`` on ``text_data``.

    Args:
        text_data: Iterable of documents passed to ``add_doc`` one by one.
        metadata: Iterable paired with ``text_data``; each value is converted
            with ``str`` and attached as the document's metadata.
        iterations: Number of single-step training rounds to run.
        prior_tp: Accepted for compatibility; not used by this function.
        word_clusters: Optional sequence of word collections. Every word of
            cluster ``j`` gets a one-hot word prior on topic ``j``.

    Returns:
        The trained model. The log-likelihood per word is printed after every
        training round.
    """
    model = tp.DMRModel(k=6, tw=tp.TermWeight.IDF, min_cf=5, min_df=5)

    # Word priors are registered before any document is added.
    if word_clusters:
        for topic_idx, cluster in enumerate(word_clusters):
            for word in cluster:
                one_hot = [0] * model.k
                one_hot[topic_idx] = 1
                model.set_word_prior(word, one_hot)

    for tokens, meta in zip(text_data, metadata):
        model.add_doc(tokens, metadata=str(meta))

    # A zero-iteration call first, then one iteration at a time so progress
    # can be reported after each step.
    model.train(0)
    for step in range(iterations):
        model.train(1)
        print(f'Iteration: {step}\tLog-likelihood: {model.ll_per_word}')
    return model

# Function to visualize topic distributions
def visualize_topic_distributions(mdl):
    """Plot each topic's share of the model's total topic counts.

    Args:
        mdl: Trained topic model exposing ``get_count_by_topics()``.

    The number of topics is taken from the counts returned by the model
    instead of being fixed at six, and the counts are fetched only once.
    """
    counts = np.asarray(mdl.get_count_by_topics(), dtype=float)
    total = counts.sum()
    # Avoid a division by zero for a model without any assigned words.
    shares = counts / total if total else np.zeros_like(counts)

    fig, ax = plt.subplots(figsize=(10, 6))
    for number, share in enumerate(shares, start=1):
        # Each series is a single value; without a marker a one-point line
        # is invisible, leaving the chart empty apart from its legend.
        ax.plot(share, marker='o', label='Topic {}'.format(number))
    # NOTE(review): the axis label and title mention metadata, but the plotted
    # values are overall per-topic shares — confirm the intended chart.
    ax.set_xlabel('Metadata')
    ax.set_ylabel('Topic Distribution')
    ax.set_title('Topic Distributions by Metadata')
    ax.legend()
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

# Function to visualize positive and negative topics distribution
def plot_pos_neg_topic_dist(mdl, topics):
    """Draw stacked bars of the mean topic mix for positive and negative documents.

    Args:
        mdl: Trained topic model. Each document's ``metadata`` is parsed with
            ``float``; values ``>= 0`` count as positive, the rest as negative.
        topics: Topic labels, indexable by ``0 .. len(topics) - 1``.

    A sentiment group without any document is drawn with zero-height bars
    instead of raising an error.
    """
    positive_topics = []
    negative_topics = []
    for doc in mdl.docs:
        bucket = positive_topics if float(doc.metadata) >= 0 else negative_topics
        bucket.append(doc.get_topic_dist())

    def mean_percentages(distributions):
        # Mean topic distribution of one sentiment group, in percent.
        if not distributions:
            return np.zeros(mdl.k)
        return np.asarray(distributions, dtype=float).mean(axis=0) * 100

    positive_pct = mean_percentages(positive_topics)
    negative_pct = mean_percentages(negative_topics)

    x = ["positive", "negative"]
    fig, ax = plt.subplots(figsize=(10, 6))

    # Stack the bars topic by topic. The running bottom is rebuilt with a new
    # array on every step so that no topic's heights are modified in place.
    bottom = np.zeros(len(x))
    for i in range(len(topics)):
        heights = np.array([positive_pct[i], negative_pct[i]])
        ax.bar(x, heights, bottom=bottom, label=topics[i])
        bottom = bottom + heights

    ax.set_xlabel('Sentiment')
    ax.set_ylabel('Percentage')
    ax.set_title('Distribution of Topics for Positive and Negative Sentiments')
    ax.legend(title='Topics', loc='upper left', bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.show()

# Function to visualize topics distribution over time
def plot_topic_distribution(mdl, data, topics):
    """Show a stacked area chart of the mean topic proportions per publication date.

    Args:
        mdl: Trained topic model; ``doc.get_topic_dist()`` is read for every
            document in ``mdl.docs``.
        data: DataFrame with a ``pubdate`` column, one row per model document
            and in the same order.
        topics: Topic labels, indexable by topic number.
    """
    topics_distribution = np.array([doc.get_topic_dist() for doc in mdl.docs])
    dates = pd.to_datetime(data['pubdate'])

    # One numeric column per topic, so the per-date mean is an ordinary
    # groupby aggregation instead of a reduction over object-dtype arrays.
    per_document = pd.DataFrame(topics_distribution, index=pd.Index(dates, name='Date'))
    daily_mean = per_document.groupby(level='Date').mean()

    # stackplot accumulates the series itself, so it must receive the plain
    # per-topic shares. Feeding it values that are already cumulative across
    # topics would inflate the later topics. Rows are rescaled to sum to 100.
    percentages = daily_mean.div(daily_mean.sum(axis=1), axis=0) * 100

    # Plot percentage area plot
    plt.figure(figsize=(18, 6))
    plt.stackplot(daily_mean.index, percentages.to_numpy().T, labels=[topics[i] for i in range(topics_distribution.shape[1])], alpha=0.7)
    plt.xlabel('Date')
    plt.ylabel('Topic Proportion (%)')
    plt.title('Topic Proportions Over Time')
    plt.tick_params(axis='x', rotation=90)
    plt.legend()
    plt.show()

# Function to visualize topics smoothed distribution over time
def plot_smoothed_topic_distribution(mdl, data, topics, window=500):
    """Show a stacked area chart of moving-average topic proportions per date.

    Args:
        mdl: Trained topic model; ``doc.get_topic_dist()`` is read for every
            document in ``mdl.docs``.
        data: DataFrame with a ``pubdate`` column, one row per model document
            and in the same order.
        topics: Topic labels, indexable by topic number.
        window: Width, in documents, of the moving average applied to each
            topic column. It is capped at the number of documents.
    """
    topics_distribution = np.array([doc.get_topic_dist() for doc in mdl.docs])

    # np.convolve(mode="same") returns max(len(signal), len(kernel)) values,
    # so a kernel longer than the document list would change the row count.
    width = max(1, min(int(window), topics_distribution.shape[0]))
    kernel = np.ones(width) / width
    smoothed_topics_distribution = np.apply_along_axis(
        lambda column: np.convolve(column, kernel, mode="same"),
        axis=0,
        arr=topics_distribution,
    )

    dates = pd.to_datetime(data['pubdate'])
    per_document = pd.DataFrame(smoothed_topics_distribution, index=pd.Index(dates, name='Date'))
    daily_mean = per_document.groupby(level='Date').mean()

    # stackplot accumulates the series itself, so it must receive the plain
    # per-topic shares. Rescaling every row to 100 also removes the damping
    # the zero-padded moving average introduces near both ends.
    percentages = daily_mean.div(daily_mean.sum(axis=1), axis=0) * 100

    # Plot percentage area plot with smooth boundaries
    plt.figure(figsize=(18, 6))
    plt.stackplot(daily_mean.index, percentages.to_numpy().T, labels=[topics[i] for i in range(smoothed_topics_distribution.shape[1])], alpha=0.7, edgecolor='none')
    plt.xlabel('Date')
    plt.ylabel('Topic Proportion (%)')
    plt.title('Smoothed Topic Proportions Over Time')
    plt.tick_params(axis='x', rotation=90)
    plt.legend()
    plt.show()

# Function to visualize topics smoothed distribution over time (other approach)
def plot_smoothed_topic_distribution_2(mdl, data, topics, window=50, drop_fraction=0.8):
    """Show a stacked area chart of smoothed topic proportions on a thinned set of dates.

    Args:
        mdl: Trained topic model; ``doc.get_topic_dist()`` is read for every
            document in ``mdl.docs``.
        data: DataFrame with a ``pubdate`` column, one row per model document
            and in the same order.
        topics: Topic labels, indexable by topic number.
        window: Width, in documents, of the moving average applied to each
            topic column. It is capped at the number of documents.
        drop_fraction: Fraction of dates removed at random before plotting.
            The first and last date are always kept.

    The dates to drop are drawn with ``np.random.choice``, so the chart
    differs between calls unless NumPy's global random state is seeded.
    """
    topics_distribution = np.array([doc.get_topic_dist() for doc in mdl.docs])

    # np.convolve(mode='same') returns max(len(signal), len(kernel)) values,
    # so a kernel longer than the document list would change the row count.
    width = max(1, min(int(window), topics_distribution.shape[0]))
    kernel = np.ones(width) / width
    smoothed_topics_distribution = np.apply_along_axis(
        lambda column: np.convolve(column, kernel, mode='same'),
        axis=0,
        arr=topics_distribution,
    )

    dates = pd.to_datetime(data['pubdate'])
    per_document = pd.DataFrame(smoothed_topics_distribution, index=pd.Index(dates, name='Date'))
    daily_mean = per_document.groupby(level='Date').mean()

    # Randomly thin the interior dates. The requested number of drops is
    # capped at the number of interior dates so that short series do not
    # ask for a larger sample than the population.
    interior_positions = np.arange(1, len(daily_mean) - 1)
    n_drop = min(int(len(daily_mean) * drop_fraction), len(interior_positions))
    if n_drop > 0:
        dropped = np.random.choice(interior_positions, size=n_drop, replace=False)
        keep = np.ones(len(daily_mean), dtype=bool)
        keep[dropped] = False
        daily_mean = daily_mean[keep]

    # stackplot accumulates the series itself, so it must receive the plain
    # per-topic shares, rescaled so that every date sums to 100.
    percentages = daily_mean.div(daily_mean.sum(axis=1), axis=0) * 100

    # Plot percentage area plot
    plt.figure(figsize=(18, 6))
    plt.stackplot(daily_mean.index, percentages.to_numpy().T, labels=[topics[i] for i in range(topics_distribution.shape[1])], alpha=0.7)
    plt.xlabel('Date')
    plt.ylabel('Topic Proportion (%)')
    plt.title('Topic Proportions Over Time')
    plt.tick_params(axis='x', rotation=90)
    plt.legend()
    plt.show()

# Function to visualize sentiment over time
def plot_sentiment_over_time(df, country_name=None):
    """Score every article with VADER and plot the scores against publication date.

    Args:
        df: DataFrame with ``Full Text`` and ``pubdate`` columns, plus a
            ``country`` column when ``country_name`` is given.
        country_name: If truthy, only rows whose ``country`` equals this
            value are used and the name is appended to the plot titles.

    Two figures are shown: a scatter of the per-article scores with a
    Savitzky-Golay smoothed curve, and a line plot of the same scores with
    their overall mean and the smoothed curve.

    Raises:
        ValueError: If no rows remain to plot.
    """
    data = df.copy()
    if country_name:
        # Copy the filtered slice so the column assignments below act on an
        # independent frame rather than on a view of ``df``'s copy.
        data = data[data["country"] == country_name].copy()
    if data.empty:
        raise ValueError("No articles available to plot sentiment over time.")

    analyzer = SentimentIntensityAnalyzer()                # Sentiment analysis with VADER
    data['Sentiment Score'] = [get_compound_score(analyzer, doc) for doc in data['Full Text'].tolist()]
    data['pubdate'] = pd.to_datetime(data['pubdate'])      # Convert 'pubdate' column to datetime format
    data = data.sort_values(by='pubdate')                  # Sort data by date

    # Savitzky-Golay smoothing needs an odd window no longer than the data and
    # longer than the polynomial order; shrink it for short series and fall
    # back to the raw scores when even the smallest valid window does not fit.
    polyorder = 3
    scores = data['Sentiment Score'].to_numpy(dtype=float)
    window = min(11, len(scores))
    if window % 2 == 0:
        window -= 1
    if window > polyorder:
        smoothed_scores = savgol_filter(scores, window_length=window, polyorder=polyorder)
    else:
        smoothed_scores = scores

    title_suffix = f' for {country_name}' if country_name else ''

    # Scatter of the per-article scores with the smoothed curve
    plt.figure(figsize=(18, 6))
    plt.scatter(data['pubdate'], data['Sentiment Score'], label='Sentiment Score (Sum)', alpha=0.5)
    plt.plot(data['pubdate'], smoothed_scores, color='purple', label='Smooth', linewidth=2)
    plt.xlabel('Date')
    plt.ylabel('Sentiment Score')
    plt.title(f'Sum of sentiment score{title_suffix}')
    plt.legend()
    plt.tight_layout()
    plt.show()

    mean_sentiment_score = data['Sentiment Score'].mean()  # Mean over all plotted articles

    # Line plot of the per-article scores with their overall mean
    plt.figure(figsize=(18, 6))
    plt.plot(data['pubdate'], data['Sentiment Score'], label='Sentiment Score (Raw)', alpha=0.5)
    plt.hlines(mean_sentiment_score, data['pubdate'].min(), data['pubdate'].max(), color='purple', label=f'Mean Score: {mean_sentiment_score:.3f}', linestyle='--')
    plt.plot(data['pubdate'], smoothed_scores, color='blue', label='Smooth', linewidth=2)
    plt.xlabel('Date')
    plt.ylabel('Sentiment Score')
    plt.title(f'Mean of sentiment score{title_suffix}')
    plt.legend()
    plt.tight_layout()
    plt.show()

# Function to visualize sentiment over time
def plot_sentiment_over_time_2(data_temp_sentiment, country_name=None):
    """Plot the daily sum and daily mean of sentiment scores and save both charts.

    Args:
        data_temp_sentiment: DataFrame with ``pubdate`` and
            ``sentiment_score`` columns.
        country_name: Optional name appended to the titles when truthy. It is
            always interpolated into the output file names, so ``None``
            yields files ending in ``_None.png``.

    The figures are written to ``results/sum_of_sentiment_score_<name>.png``
    and ``results/mean_of_sentiment_score_<name>.png``; the ``results``
    directory is created when missing.

    Raises:
        ValueError: If ``data_temp_sentiment`` has no rows.
    """
    if len(data_temp_sentiment) == 0:
        raise ValueError("No sentiment scores available to plot.")

    def smooth(values, window_length=15, polyorder=3):
        # Savitzky-Golay smoothing needs an odd window no longer than the data
        # and longer than the polynomial order; shrink it for short series and
        # return the raw values when no valid window fits.
        values = np.asarray(values, dtype=float)
        window = min(window_length, len(values))
        if window % 2 == 0:
            window -= 1
        if window <= polyorder:
            return values
        return savgol_filter(values, window_length=window, polyorder=polyorder)

    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    title_suffix = f' for {country_name}' if country_name else ''
    daily_scores = data_temp_sentiment.groupby("pubdate")["sentiment_score"]

    # Sum of sentiment score
    data_temp_sum = daily_scores.sum().reset_index()
    smoothed_sum = smooth(data_temp_sum['sentiment_score'])
    plt.figure(figsize=(18, 6))
    plt.scatter(data_temp_sum['pubdate'], data_temp_sum['sentiment_score'], label='Sum of daily sentiment score', alpha=0.5)
    plt.plot(data_temp_sum['pubdate'], smoothed_sum, color='purple', label='Smooth', linewidth=2)
    plt.xlabel('Date')
    plt.ylabel('Sentiment Score')
    plt.title(f'Sum of daily sentiment score{title_suffix}')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'results/sum_of_sentiment_score_{country_name}.png')  # Save the plot as an image
    plt.show()

    # Mean of sentiment score
    data_temp_mean = daily_scores.mean().reset_index()
    smoothed_mean = smooth(data_temp_mean['sentiment_score'])
    plt.figure(figsize=(18, 6))
    plt.scatter(data_temp_mean['pubdate'], data_temp_mean['sentiment_score'], label='Mean of daily sentiment score', alpha=0.5)
    plt.plot(data_temp_mean['pubdate'], smoothed_mean, color='purple', label='Smooth', linewidth=2)
    plt.xlabel('Date')
    plt.ylabel('Sentiment Score')
    plt.title(f'Mean of daily sentiment score{title_suffix}')
    plt.legend()
    plt.tight_layout()
    plt.savefig(f'results/mean_of_sentiment_score_{country_name}.png')  # Save the plot as an image
    plt.show()

def plot_sentiment_articles(data, threshold, country_name=None):
    """Show a two-colour histogram of article sentiment scores.

    Args:
        data: DataFrame with a ``sentiment_score`` column.
        threshold: Scores ``<= threshold`` form the "Negative" group and
            scores ``> threshold + 0.05`` form the "Positive" group. Scores
            between those two bounds belong to neither group and are not
            drawn.
        country_name: Optional name appended to the title when truthy.
    """
    scores = data['sentiment_score']
    positive_scores = scores[scores > threshold + 0.05]
    negative_scores = scores[scores <= threshold]

    if country_name:
        title = f'Number of Articles for Each Sentiment Score for {country_name}'
    else:
        title = 'Number of Articles for Each Sentiment Score'

    # Both groups share one set of 20 bins and are drawn side by side.
    plt.figure(figsize=(8, 6))
    plt.hist(
        [positive_scores, negative_scores],
        bins=20,
        color=['lightblue', 'lightcoral'],
        label=['Positive', 'Negative'],
        stacked=False,
        edgecolor='none',
        rwidth=10,
    )
    plt.xlabel('Sentiment Score')
    plt.ylabel('Number of Articles')
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()